{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4811a7e7-ae06-4930-b4eb-62645888d3dc",
   "metadata": {},
   "source": [
    "# Model inference on test data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dfce6a9-dad9-4308-a874-fa6e7fe5f6d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install through the %pip magic so the packages land in the running kernel's environment.\n",
    "%pip install emoji transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c63d83-657d-4ff5-8b8c-618c62430a31",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28f629c0-df3e-4dbb-8a6d-415eeb5c7f59",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inference is pinned to the CPU; the commented-out expression would select CUDA when available.\n",
    "device = 'cpu' #'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "# Ask PyTorch to release any cached CUDA memory before the model is loaded.\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4cf08014-99d6-47f8-b788-02142937dba7",
   "metadata": {},
   "source": [
    "## Read"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39ebb259-b23f-42ab-98f8-6d8b50b0c34d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Root directory for the CSV files read and written by this notebook.\n",
    "data_directory_path = Path(\"../data/\")\n",
    "data_path = data_directory_path.joinpath(\"train_dataset_train.csv\")\n",
    "\n",
    "# NOTE(review): the inference input currently aliases the training file;\n",
    "# the disabled alternative is data_directory_path / \"val.csv\".\n",
    "val_data_path = data_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f6e7001-7ebf-43fc-a837-823800db4050",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e6e0fdb-704d-4575-adfc-1c720b7d98ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both files are semicolon-separated.\n",
    "data = pd.read_csv(data_path, sep=\";\")\n",
    "\n",
    "# Only the last 90 rows of the second file are kept for inference.\n",
    "val_data = pd.read_csv(val_data_path, sep=\";\").iloc[-90:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dec6f3f-0c8a-4f55-a4ca-4c5036a96e39",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the loaded training frame.\n",
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a02a85d-39ba-4ec9-b679-89dda10f337b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the incident text column is needed for inference; keep it as a one-column frame.\n",
    "val_data = val_data[\"Текст инцидента\"].to_frame()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d1f36d8-8954-443c-8f90-aa648e0bd7ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the texts that will be classified.\n",
    "val_data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2992d6cc-45cc-4ea1-9188-ef254d6519c3",
   "metadata": {},
   "source": [
    "## Preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a813db4-2262-4624-a498-80907ee2ec5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import preprocess"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f466b8f8-8c69-442c-983f-6b7bc0c02cd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the project's preprocessing (utils.preprocess.preprocess_data) on both frames,\n",
    "# replacing the raw frames with the returned ones.\n",
    "data = preprocess.preprocess_data(data)\n",
    "val_data = preprocess.preprocess_data(val_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed4082b2-4513-4dba-91dd-1d439a0a2937",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03d9cea7-1c18-4927-b48b-52ab78ca8eff",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.dataset import prepare_data_for_dataset, MessagesDataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abd8b935-ebe1-436f-94fd-ea26999a0aef",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import LabelEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ffa28cb-d8c2-4f90-8ac3-6e4290408d9a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fit one label encoder per target column, in the order: executor, theme group, theme.\n",
    "executor_encoder, group_encoder, theme_encoder = (\n",
    "    LabelEncoder().fit(data[column])\n",
    "    for column in (\"Исполнитель\", \"Группа тем\", \"Тема\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7b94149-a234-4521-8150-1009b95eaafc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the three label columns with their encoded values.\n",
    "# NOTE(review): the columns are overwritten in place, so this cell is meant to run once per\n",
    "# kernel session; re-running it passes the already-encoded values back into `transform`.\n",
    "data[\"Исполнитель\"] = executor_encoder.transform(data[\"Исполнитель\"])\n",
    "data[\"Группа тем\"]  = group_encoder.transform(data[\"Группа тем\"])\n",
    "data[\"Тема\"]        = theme_encoder.transform(data[\"Тема\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be561d04-1284-40e6-aef3-8bb88e073e5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b05f6fb6-d721-495c-bf4d-a5fdeed7e3d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out 1% of the prepared arrays with a fixed seed; only the held-out parts are kept,\n",
    "# the training parts are discarded into `_`.\n",
    "# NOTE(review): the unpacking assumes prepare_data_for_dataset returns\n",
    "# (executor, group, message, theme) in this order -- confirm against utils.dataset.\n",
    "_, test_executor, _, test_group, _, test_message, _, test_theme = train_test_split(\n",
    "    *prepare_data_for_dataset(data),\n",
    "    test_size=0.01,\n",
    "    random_state=42,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "053151cb-2639-4b87-9dd5-08930f4bd2bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `_` still references the last discarded training part from the split above; drop it.\n",
    "del _"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "81ed677f-fce6-4ec4-bc21-94506ba346a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only the held-out split is wrapped into a dataset here; the training variant stays disabled.\n",
    "#train_dataset = MessagesDataset(train_message, train_executor, train_theme, train_group)\n",
    "test_dataset  = MessagesDataset(test_message, test_executor, test_theme, test_group)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd29e36e-fb86-46f0-a8c8-d5303fa7fe92",
   "metadata": {},
   "outputs": [],
   "source": [
    "class InferenceDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"\n",
    "    Dataset of raw messages without labels, used for inference.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    messages : iterable\n",
    "        Messages to serve. They are copied into a list, so any iterable is accepted\n",
    "        (for example a pandas Series whose index does not start at zero) and items\n",
    "        are always addressed by position.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, messages):\n",
    "        # Items are requested by integer position; copying into a list decouples that\n",
    "        # from whatever index or labels the input container uses.\n",
    "        self.messages = list(messages)\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Return the number of stored messages.\"\"\"\n",
    "        return len(self.messages)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the message stored at position `idx`.\"\"\"\n",
    "        return self.messages[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db51eb27-4d70-4411-b850-8a51162efb9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Wrap the preprocessed incident texts (as a plain list) for label-free inference.\n",
    "val_dataset = InferenceDataset(val_data[\"Текст инцидента\"].tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "17438d07-6392-438c-b211-1029d9309407",
   "metadata": {},
   "source": [
    "## Model\n",
    "\n",
    "Используется двуголовая модель."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e897e31c-07ac-497e-bab6-2470f87d8082",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory and file name of the locally stored classifier weights.\n",
    "models_path = Path(\"../models\")\n",
    "local_weights_filename = \"full_ruroberta_5.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dc3c04a-dca8-4046-86eb-86c4c211cc39",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoConfig, AutoModel, AutoTokenizer\n",
    "\n",
    "# Backbone candidates; the active one is assigned to language_model_name below:\n",
    "# \"sberbank-ai/ruRoberta-large\"\n",
    "# \"DeepPavlov/distilrubert-base-cased-conversational\"\n",
    "# \"xlm-roberta-base\"\n",
    "language_model_name = \"sberbank-ai/ruRoberta-large\"\n",
    "\n",
    "# The backbone is built from its configuration only; the weights are expected to come from the\n",
    "# local checkpoint that is loaded into the classifier later in the notebook.\n",
    "language_model = AutoModel.from_config(AutoConfig.from_pretrained(language_model_name))\n",
    "# Tokenizer of the same model, with the maximum sequence length set to 512.\n",
    "tokenizer = AutoTokenizer.from_pretrained(language_model_name, model_max_length=512)\n",
    "#tokenizer.add_special_tokens({'pad_token': '[PAD]'})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e21a50a0-9189-4704-be55-d03a2d205a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both loaders keep the original order (shuffle=False), so outputs stay aligned with the input rows.\n",
    "#train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)\n",
    "test_dataloader  = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False)\n",
    "val_dataloader   = torch.utils.data.DataLoader(val_dataset, batch_size=128, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aada3a41-2367-48dc-bb76-6cb64ffb8fec",
   "metadata": {},
   "outputs": [],
   "source": [
    "nllloss = torch.nn.NLLLoss()\n",
    "\n",
    "\n",
    "def classifier_loss(y_pred, y):\n",
    "    \"\"\"\n",
    "    Combine the per-head NLL losses into a single scalar.\n",
    "\n",
    "    `y_pred` and `y` are indexable pairs: element 0 belongs to the first head and\n",
    "    element 1 to the second head. The first head is weighted by 0.0 and the second\n",
    "    by 1.0; both terms are still evaluated. Setting both weights to 1.0 would give\n",
    "    the plain sum of the two head losses.\n",
    "    \"\"\"\n",
    "    first_head_term = 0.0 * nllloss(y_pred[0], y[0])\n",
    "    second_head_term = 1.0 * nllloss(y_pred[1], y[1])\n",
    "    return first_head_term + second_head_term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "47f5dc36-2875-4129-a345-6245712190a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.classifier import Classifier, evaluate_classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fab506af-6824-41be-b968-b2c9bdd86751",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assemble the classifier around the backbone and tokenizer, then move it to the target device.\n",
    "# NOTE(review): hid_size=1024 presumably matches the hidden size of the chosen backbone -- confirm.\n",
    "classifier = Classifier(language_model, tokenizer, preprocess.theme_to_group(data), executor_encoder, theme_encoder, hid_size=1024).to(device)\n",
    "\n",
    "# Restore the trained weights from the local checkpoint, mapped onto the target device.\n",
    "# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.\n",
    "classifier.load_state_dict(torch.load(models_path / local_weights_filename, map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc57f2fe-70aa-43c0-939d-d4e4967e1fa0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the restored classifier on the held-out split as a sanity check.\n",
    "evaluate_classifier(classifier, test_dataloader, classifier_loss, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "837c4ee3-df0d-43ac-b725-56423ada073c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The held-out split is not used after the evaluation; remove its names from the namespace.\n",
    "del test_dataloader, test_executor, test_group, test_message, test_theme"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30ddce98-73fd-4ce8-b764-605765699930",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b8a4697-0e47-4a7a-b8e9-4ee124f6fa78",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4b0a5f2b-4976-4bff-a1b1-4f3cce6be4f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference_classifier(classifier, dataloader, device) -> tuple[np.ndarray, np.ndarray, np.ndarray]:\n",
    "    \"\"\"\n",
    "    Run the classifier over a dataloader of raw messages and collect class predictions.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    classifier\n",
    "        Model exposing `tokenizer`, `theme_to_group`, `training`, `eval()` and `train()`;\n",
    "        calling it on a tokenized batch must return a pair of per-class score tensors.\n",
    "    dataloader\n",
    "        Iterable yielding batches of messages accepted by `classifier.tokenizer`.\n",
    "    device\n",
    "        Device the tokenized batch is moved to before the forward pass.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    tuple of three 1-D numpy arrays\n",
    "        Argmax indices of the first head, argmax indices of the second head, and the\n",
    "        second-head indices mapped through `classifier.theme_to_group`.\n",
    "        All three arrays are empty when the dataloader yields no batches.\n",
    "    \"\"\"\n",
    "\n",
    "    # Exit training mode.\n",
    "    was_in_training = classifier.training\n",
    "    classifier.eval()\n",
    "\n",
    "    # Per-batch predictions of the two heads.\n",
    "    y_pred_all_1 = []\n",
    "    y_pred_all_2 = []\n",
    "\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            for message in tqdm(dataloader):\n",
    "                tokens = classifier.tokenizer(message, padding=True, truncation='only_first',\n",
    "                                              return_tensors=\"pt\").to(device)\n",
    "                del message\n",
    "\n",
    "                y_pred_1, y_pred_2 = classifier(tokens)\n",
    "\n",
    "                y_pred_all_1.append(np.argmax(y_pred_1.detach().cpu().numpy(), axis=1))\n",
    "                y_pred_all_2.append(np.argmax(y_pred_2.detach().cpu().numpy(), axis=1))\n",
    "                del y_pred_1\n",
    "                del y_pred_2\n",
    "    finally:\n",
    "        # Return to the original mode, even when a batch fails.\n",
    "        classifier.train(was_in_training)\n",
    "\n",
    "    # np.concatenate and np.vectorize both reject empty input, so handle that case explicitly.\n",
    "    if not y_pred_all_1:\n",
    "        empty = np.empty(0, dtype=np.int64)\n",
    "        return empty, empty.copy(), empty.copy()\n",
    "\n",
    "    y_pred_all_1 = np.concatenate(y_pred_all_1)\n",
    "    y_pred_all_2 = np.concatenate(y_pred_all_2)\n",
    "    y_pred_all_3 = np.vectorize(classifier.theme_to_group.get)(y_pred_all_2)\n",
    "\n",
    "    return y_pred_all_1, y_pred_all_2, y_pred_all_3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ebba7e2-2b26-44c5-9db4-43936cb65ef4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run inference, then decode the theme and theme-group indices back to their labels.\n",
    "# The first returned array (index 0) is not used here.\n",
    "prediction_columns = inference_classifier(classifier, val_dataloader, device)\n",
    "prediction = pd.DataFrame({\n",
    "    \"Тема\": theme_encoder.inverse_transform(prediction_columns[1]),\n",
    "    \"Группа тем\": group_encoder.inverse_transform(prediction_columns[2]),\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef95ce7f-8ff8-46c1-8344-4385469eec88",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The predictions file is named after the weights file, with \".csv\" appended,\n",
    "# and is written semicolon-separated into the data directory.\n",
    "prediction_file_name = f\"{local_weights_filename}.csv\"\n",
    "prediction.to_csv(data_directory_path / prediction_file_name, sep=\";\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e551add0-58ac-4d86-9785-ae9b4cab8f51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the decoded predictions.\n",
    "prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d93e422-87fa-4dd6-b370-3c13666e4929",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
